{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem: Implement parameter initialization strategies for a CNN model in Pytorch\n",
    "\n",
    "### Problem Statement\n",
    "You are tasked with employing and evaluating a CNN model\\'s parameter initialization strategies in Pytorch. \n",
    "Your goal is to initialize the weights and biases of a vanilla CNN model provided in the problem statement and comment on the implications of each strategy.\n",
    "\n",
    "### Requirements\n",
    "1. **Initialize** weights and biases in the following ways:\n",
    "   - **Zero Initialization**: set the parameters to zero\n",
    "   - **Random Initialization**: sets model parameters to random values drawn from a normal distribution \n",
    "   - **Xavier Initialization** sets them to random values from a normal distribution with **mean=0 and variance=1\\/n**\n",
    "   - **Kaiming He Initialization** initializes to random values from a normal distribution with **mean=0 and variance=2\\/n**\n",
    "2. Train and compute accuracy for each strategy\n",
    "### Constraints\n",
    "- Use the given CNN model and the training and testing helper functions for accuracy computations.\n",
    "- Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.\n",
    "\n",
    "\n",
    "<details>\n",
    "  <summary>💡 Hint</summary>\n",
    "  - Use `torch.nn.init` for weight initialization\n",
    "  <br>\n",
    "  - Resources to read: [All you need is a good init](https://arxiv.org/pdf/1511.06422)\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "\n",
    "test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_test_loop(model, train_loader, test_loader, epochs=10):\n",
    "    model.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        for image, label in train_loader:\n",
    "            pred = model(image)\n",
    "            loss = criterion(pred, label)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()    \n",
    "        print(f\"Training loss at epoch {epoch} = {loss.item()}\")\n",
    "    \n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for image_test, label_test in test_loader:\n",
    "            pred_test = model(image_test)\n",
    "            _, pred_test_vals = torch.max(pred_test, dim=1)\n",
    "            total += label_test.size(0)\n",
    "            correct += (pred_test_vals == label_test).sum().item()\n",
    "    print(f\"Test Accuracy = {(correct * 100)/total}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VanillaCNNModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)\n",
    "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.fc1 = nn.Linear(64*16*16, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "        self.relu = nn.ReLU()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.conv1(x))\n",
    "        x = self.pool(self.relu(self.conv2(x)))\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def config_init(init_type=\"kaiming\"):\n",
    "    # TODO: Add Kaiming initialization according to a CNN or a Linear layer\n",
    "    def kaiming_init(m):\n",
    "        ...\n",
    "            \n",
    "    # TODO: Implement Xavier initialization        \n",
    "    def xavier_init(m):\n",
    "        ...\n",
    "    \n",
    "    # TODO: Initialize weights and biases to zero          \n",
    "    def zeros_init(m):\n",
    "        ...\n",
    "    \n",
    "    # TODO: Initialize weights and biases from a normal distribution          \n",
    "    def random_init(m):\n",
    "        ...\n",
    "    \n",
    "\n",
    "    initializer_dict = {\"kaiming\": kaiming_init,\n",
    "                        \"xavier\": xavier_init,\n",
    "                        \"zeros\": zeros_init,\n",
    "                        \"random\": random_init}\n",
    "    \n",
    "    return initializer_dict.get(init_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for name, model in zip([\"Vanilla\", \"Kaiming\", \"Xavier\", \"Zeros\", \"Random\"], [VanillaCNNModel(),\n",
    "              VanillaCNNModel().apply(config_init(\"kaiming\")),\n",
    "              VanillaCNNModel().apply(config_init(\"xavier\")),\n",
    "              VanillaCNNModel().apply(config_init(\"zeros\")),\n",
    "              VanillaCNNModel().apply(config_init(\"random\"))\n",
    "              ]):\n",
    "    print(f\"_________{name}_______________________\")\n",
    "    train_test_loop(model, train_loader, test_loader)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mylogs--5zRa99S-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
